// Copyright © 2024 Apple Inc.

#include <metal_integer>
#include <metal_math>

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/ternary_ops.h"
#include "mlx/backend/metal/kernels/ternary.h"

// Expands to the base set of ternary kernel instantiations for one
// (op, type) pair, via `instantiate_kernel` (declared outside this file).
//
//   op    - op functor name; passed as a template argument and stringized
//           into the kernel name.
//   tname - type tag that is only stringized into the kernel name.
//   type  - element type passed as a template argument.
//
// The first argument of every invocation is three adjacent string literals:
// a prefix, #op and #tname (e.g. "v_" "Select" "float32", which is
// "v_Selectfloat32" once adjacent literals are concatenated).
//
// Variants emitted:
//   v_, vs_, sv_              - `ternary_v` with a trailing `1` argument and
//                               the bool pairs (false, false), (false, true),
//                               (true, false) respectively.
//   v2_, vs2_, sv2_           - `ternary_v2` with the same bool pairs.
//   gn2_                      - `ternary_g` with arguments `2, int`.
//   g1_, g2_, g3_             - `ternary_g_nd1/2/3` with a trailing `int`.
//   g1large_, g2large_,
//   g3large_                  - `ternary_g_nd1/2/3` without the trailing `int`.
//   gn4large_                 - `ternary_g` with argument `4` and no `int`.
//
// NOTE(review): the bool pair presumably marks which inputs are scalars, and
// the omitted `int` presumably selects a wider default index type in the
// templates - confirm against ternary.h.
//
// The last line ends with a line-continuation backslash, so the macro body
// extends onto the following line; that line must stay blank.
#define instantiate_ternary_base(op, tname, type)                    \
  instantiate_kernel("v_" #op #tname, ternary_v, type, op, false, false, 1) \
  instantiate_kernel("v2_" #op #tname, ternary_v2, type, op, false, false)  \
  instantiate_kernel("vs_" #op #tname, ternary_v, type, op, false, true, 1) \
  instantiate_kernel("vs2_" #op #tname, ternary_v2, type, op, false, true)  \
  instantiate_kernel("sv_" #op #tname, ternary_v, type, op, true, false, 1) \
  instantiate_kernel("sv2_" #op #tname, ternary_v2, type, op, true, false)  \
  instantiate_kernel("gn2_" #op #tname, ternary_g, type, op, 2, int) \
  instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op, int) \
  instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op, int) \
  instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op, int) \
  instantiate_kernel("g1large_" #op #tname, ternary_g_nd1, type, op) \
  instantiate_kernel("g2large_" #op #tname, ternary_g_nd2, type, op) \
  instantiate_kernel("g3large_" #op #tname, ternary_g_nd3, type, op) \
  instantiate_kernel("gn4large_" #op #tname, ternary_g, type, op, 4) \

// Expands to the full set of ternary kernel instantiations for one
// (op, type) pair: three extra `ternary_v` variants followed by everything
// from `instantiate_ternary_base`.
//
// The extra variants are named with the prefixes "vn_", "vsn_" and "svn_"
// and use the bool pairs (false, false), (false, true) and (true, false).
// Unlike the "v_"/"vs_"/"sv_" variants in the base macro, they omit the
// trailing `1` argument.
// NOTE(review): presumably the omitted argument falls back to a template
// default in `ternary_v` - confirm against ternary.h.
//
// Parameters are the same as for `instantiate_ternary_base`.
#define instantiate_ternary_all(op, tname, type)            \
  instantiate_kernel("vn_" #op #tname, ternary_v, type, op, false, false) \
  instantiate_kernel("vsn_" #op #tname, ternary_v, type, op, false, true) \
  instantiate_kernel("svn_" #op #tname, ternary_v, type, op, true, false) \
  instantiate_ternary_base(op, tname, type)

// Instantiates the ternary kernels for `op` across every element type
// listed below.
//
// Most types go through `instantiate_ternary_all`. uint64_t, int64_t and
// complex64_t go through `instantiate_ternary_base` only, so no "vn_",
// "vsn_" or "svn_" kernels are produced for those three types.
//
// The second argument of each line is the tag stringized into the kernel
// name. For `bool` that tag is `bool_` (with a trailing underscore), so the
// bool kernel names end in "bool_".
#define instantiate_ternary_types(op)               \
  instantiate_ternary_all(op, bool_, bool)          \
  instantiate_ternary_all(op, uint8, uint8_t)       \
  instantiate_ternary_all(op, uint16, uint16_t)     \
  instantiate_ternary_all(op, uint32, uint32_t)     \
  instantiate_ternary_base(op, uint64, uint64_t)    \
  instantiate_ternary_all(op, int8, int8_t)         \
  instantiate_ternary_all(op, int16, int16_t)       \
  instantiate_ternary_all(op, int32, int32_t)       \
  instantiate_ternary_base(op, int64, int64_t)      \
  instantiate_ternary_all(op, float16, half)        \
  instantiate_ternary_all(op, float32, float)       \
  instantiate_ternary_all(op, bfloat16, bfloat16_t) \
  instantiate_ternary_base(op, complex64, complex64_t) // clang-format on

// Emit every ternary kernel variant, for every listed element type, for the
// `Select` op. This is the only op instantiated in this file.
// NOTE(review): `Select` is presumably defined in ternary_ops.h - confirm.
instantiate_ternary_types(Select)
